import os

import numpy as np
import torch

# Set CUDA_LAUNCH_BLOCKING=1 for this process before any project module is imported.
# NOTE(review): per the CUDA documentation this makes kernel launches synchronous, which
# is a debugging aid and slows training — confirm it is still wanted for normal runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import datetime
import json
import sys

# Append the current working directory, with every '/interface' substring removed, to
# the module search path, then print the resulting path for inspection.
# NOTE(review): presumably this lets the project packages imported below resolve when
# the script is launched from an `interface/` sub-directory — confirm against the repo
# layout. The replacement is a plain substring replace using a POSIX-style separator.
cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
print(sys.path)
from generic.model_util import get_distrib_q_model_save_path, to_pt, get_maf_save_path
from player_ranking.player_evaluation_metric import run_risk_sensitive_player_evaluation
from agent import SportsAgent
from density_model.maf_model import update_maf, validate_maf
from evaluate.evaluate_distrib_rl import visualize_uncertainty_by_location, contextualized_empirical_risk_measure
from generic.data_util import read_args, load_config, HistoryScoreCache, ICEHOCKEY_ACTIONS, \
    divide_dataset_according2date, \
    Transition, read_feature_mean_scale


def _update_maf_on_game(agent, transition_all, values_cond, sanity_check_msg,
                        running_avg_maf_loss, running_avg_log_prob):
    """Run `update_maf` over every mini-batch of one game's transitions.

    Args:
        agent: the agent whose MAF model is updated; `batch_size`,
            `maf_cond_value` and `get_transition_batch` are read from it.
        transition_all: all transitions built for the current game.
        values_cond: per-game conditioning values, or None when the agent does
            not condition the MAF on values.
        sanity_check_msg: sanity-check label forwarded to `update_maf`.
        running_avg_maf_loss: cache receiving each non-None batch loss.
        running_avg_log_prob: cache receiving each non-None batch log-prob.
    """
    counter = 0
    end = False
    # At least one batch is always requested, even for a very short game.
    while not end:
        batch_data = agent.get_transition_batch(transition_all=transition_all, counter=counter)
        batch_stop = (counter + 1) * agent.batch_size
        if agent.maf_cond_value:
            # assumes values_cond is row-aligned with transition_all — TODO confirm
            batch_values_cond = values_cond[counter * agent.batch_size: batch_stop]
        else:
            # Without conditioning there is nothing to slice. None is passed
            # explicitly (the same value handed to validate_maf in this case);
            # previously the name was left unbound and raised UnboundLocalError.
            batch_values_cond = None

        loss, log_prob = update_maf(agent=agent,
                                    batch=batch_data,
                                    sanity_check_msg=sanity_check_msg,
                                    batch_values_cond=batch_values_cond)
        if batch_stop >= len(transition_all):
            end = True
        counter += 1
        if loss is not None and log_prob is not None:
            running_avg_maf_loss.push(loss.detach().cpu().item())
            running_avg_log_prob.push(log_prob.detach().cpu().item())


def _report_and_checkpoint(agent, episode_num, transition_all, values_cond,
                           maf_model_save_mother_dir, dqn_load_from_path,
                           sanity_check_msg, debug_mode, debug_msg,
                           running_avg_maf_loss, running_avg_log_prob, log_file):
    """Evaluate the current MAF model, log running statistics and save a checkpoint.

    Args:
        agent: the agent holding the MAF model.
        episode_num: number of games processed so far; used as the iteration label.
        transition_all: transitions of the most recently trained game.
        values_cond: conditioning values of that game, or None.
        maf_model_save_mother_dir: directory that receives `saved_model_<episode>`.
        dqn_load_from_path: path of the pretrained model, recorded next to the
            correlation results.
        sanity_check_msg: sanity-check label forwarded to the evaluation helpers.
        debug_mode: whether the run is in debug mode.
        debug_msg: prefix used in model labels.
        running_avg_maf_loss: cache holding recent training losses.
        running_avg_log_prob: cache holding recent training log-probs.
        log_file: open log file, or None to print to stdout.
    """
    return_correlations, maf_model_label = run_risk_sensitive_player_evaluation(
        agent=agent,
        model_save_path=maf_model_save_mother_dir,
        uncertainty_model='maf',
        iteration=episode_num,
        log_file=log_file,
        sanity_check_msg=sanity_check_msg,
        debug_mode=debug_mode,
        debug_msg=debug_msg, )

    # Record once which pretrained model these correlation results were computed with.
    # NOTE(review): the result directory is presumably created by the evaluation call
    # above — confirm.
    dqn_name_record = './correlation_results/' + maf_model_label + '/read_qr_dqn_model_name.txt'
    if not os.path.exists(dqn_name_record):
        with open(dqn_name_record, 'w') as rfile:
            rfile.write(dqn_load_from_path)

    model_label = debug_msg + 'maf_' + agent.sports + maf_model_save_mother_dir.split('maf')[-1]
    contextualized_empirical_risk_measure(agent=agent,
                                          model_label=model_label,
                                          episode_num=episode_num,
                                          debug_mode=debug_mode,
                                          mode='test',
                                          uncertainty_model='maf',
                                          sports=agent.sports
                                          )

    # Log the first 20 entries of the log-prob output for the whole game just trained on.
    game_data = Transition(*zip(*transition_all))
    _, log_prob = validate_maf(agent=agent,
                               batch=game_data,
                               sanity_check_msg=sanity_check_msg,
                               batch_values_cond=values_cond)
    print(log_prob[:20], file=log_file, flush=True)

    visualize_uncertainty_by_location(agent=agent,
                                      debug_mode=debug_mode,
                                      sanity_check_msg=sanity_check_msg,
                                      episode_num=episode_num,
                                      target_action='shot' if agent.sports == 'ice-hockey' else 'standard_shot',
                                      mode='test',
                                      uncertainty_model='maf',
                                      model_label='heatmap' +
                                                  maf_model_save_mother_dir.split('saved')[-1],
                                      if_plot_num=True,
                                      log_file=log_file)

    # No held-out loss is computed here, so the testing fields are reported as None.
    test_loss = None
    test_log_prob = None
    print("Episode: {0}, Training Loss: {1}, Training Log Prob: {2}, "
          "Testing Loss: {3}, Testing Log Prob: {4}.".format(
        episode_num,
        running_avg_maf_loss.get_avg(),
        running_avg_log_prob.get_avg(),
        test_loss,
        test_log_prob
    ), file=log_file, flush=True)

    agent.save_maf(save_to_path=maf_model_save_mother_dir + '/saved_model_{0}'.format(episode_num),
                   episode_no=episode_num,
                   avg_log_prob=running_avg_log_prob.get_avg(),
                   avg_loss=running_avg_maf_loss.get_avg(),
                   log_file=log_file)
    print("", file=log_file, flush=True)


def train(args):
    """Train the agent's MAF density model game by game, evaluating and saving periodically.

    The configuration is loaded from `args`, a `SportsAgent` is built, the
    pretrained model selected by `args.LOAD_DATE` / `args.LOAD_EPISODE` is
    loaded when its file exists, and the MAF model is then updated on each
    training game until more than `agent.max_episode` games have been processed.

    Args:
        args: parsed command-line arguments. Reads `DEBUG_MODE`, `LEARN_MODE`
            ('normal' or 'location_ha'), `CHECK_POINT`, `LOAD_EPISODE` and
            `LOAD_DATE`, and is also passed to `load_config`.

    Raises:
        ValueError: if `args.LEARN_MODE` is unknown, or if no training file is
            available.
    """
    # The debug flag returned by load_config is ignored: args.DEBUG_MODE decides.
    config, _, log_file_path = load_config(args)
    log_file = open(log_file_path, 'w') if log_file_path is not None else None

    try:
        debug_mode = bool(args.DEBUG_MODE)
        debug_msg = 'debug_' if debug_mode else ''
        sanity_check_msg = None

        config['general']['checkpoint']['report_frequency'] = 1000

        if args.LEARN_MODE == 'location_ha':
            print('-' * 100, file=log_file, flush=True)
            print("*** Warning: Launching the sanity check. ***", file=log_file, flush=True)
            config['general']['model']['input_dim'] = len(ICEHOCKEY_ACTIONS) + 4
            sanity_check_msg = 'sanity_check_location_ha_'  # sanity_check_location_ha_, sanity_check_sd_md_tr_ha_
            debug_msg = sanity_check_msg + debug_msg
            print('-' * 100, file=log_file, flush=True)
        elif args.LEARN_MODE != 'normal':
            raise ValueError("Unknown learning mode {0}".format(args.LEARN_MODE))

        print(json.dumps(config, indent=4), file=log_file, flush=True)

        agent = SportsAgent(config=config, log_file=log_file)
        all_files = sorted(os.listdir(agent.train_data_path))
        training_files, _, _ = divide_dataset_according2date(all_data_files=all_files,
                                                             train_rate=agent.train_rate,
                                                             sports=agent.sports,
                                                             if_split=agent.apply_data_date_div,
                                                             )
        if len(training_files) == 0:
            # Fail with a clear message instead of a ZeroDivisionError in the loop below.
            raise ValueError("No training files found in {0}".format(agent.train_data_path))

        running_avg_maf_loss = HistoryScoreCache(capacity=500)
        running_avg_log_prob = HistoryScoreCache(capacity=500)

        if args.CHECK_POINT is not None:
            date_label = args.CHECK_POINT
        else:
            date_label = datetime.datetime.now().strftime('%b-%d-%Y-%H:%M')

        maf_model_save_mother_dir = get_maf_save_path(agent=agent,
                                                      date_label=date_label,
                                                      debug_msg=debug_msg)
        # Race-free replacement for the duplicated "check, then mkdir" sequence.
        os.makedirs(maf_model_save_mother_dir, exist_ok=True)

        dqn_model_save_mother_dir = get_distrib_q_model_save_path(agent, args.LOAD_DATE, debug_msg)
        dqn_load_from_path = dqn_model_save_mother_dir + '/saved_model_{0}'.format(args.LOAD_EPISODE)
        print('Loading dqn from {0}'.format(dqn_load_from_path), file=log_file, flush=True)
        if os.path.isfile(dqn_load_from_path):
            agent.load_pretrained_model(load_from=dqn_load_from_path,
                                        log_file=log_file)
        else:
            # Training still proceeds (as before), but the miss is no longer silent.
            print('*** Warning: no dqn checkpoint found at {0}; nothing was loaded. ***'.format(
                dqn_load_from_path), file=log_file, flush=True)

        episode_num = 0
        while episode_num <= agent.max_episode:
            # Resume from the game that follows the last processed one.
            start_file_idx = episode_num % len(training_files)
            for file_idx in range(start_file_idx, len(training_files)):
                file_name = training_files[file_idx]
                # Re-enter training mode for every game.
                # NOTE(review): presumably the evaluation helpers switch the model to
                # eval mode — confirm.
                agent.maf_model.train()
                s_a_sequence, r_sequence = agent.load_sports_data(game_label=file_name,
                                                                  sanity_check_msg=sanity_check_msg)
                pid_sequence = agent.load_player_id(game_label=file_name)
                if agent.apply_rnn:
                    transition_all = agent.build_rnn_transitions(s_a_data=s_a_sequence,
                                                                 r_data=r_sequence,
                                                                 pid_sequence=pid_sequence)
                else:
                    transition_all = agent.build_transitions(s_a_data=s_a_sequence,
                                                             r_data=r_sequence,
                                                             pid_sequence=pid_sequence)

                if agent.maf_cond_value:
                    output_game, _ = agent.compute_values_by_game(game_name=file_name,
                                                                  sanity_check_msg=sanity_check_msg)
                    # Mean over axis 2 of the per-game output.
                    # NOTE(review): presumably the quantile axis — confirm.
                    values_cond = np.mean(output_game, axis=2)
                    values_cond = to_pt(values_cond, enable_cuda=agent.enable_cuda, type='float')
                else:
                    values_cond = None

                _update_maf_on_game(agent=agent,
                                    transition_all=transition_all,
                                    values_cond=values_cond,
                                    sanity_check_msg=sanity_check_msg,
                                    running_avg_maf_loss=running_avg_maf_loss,
                                    running_avg_log_prob=running_avg_log_prob)
                episode_num += 1

                # Outside debug mode, report when the episode counter wraps around the
                # report frequency, i.e. when it reaches a multiple of it.
                if debug_mode or episode_num % agent.report_frequency < (episode_num - 1) % agent.report_frequency:
                    _report_and_checkpoint(agent=agent,
                                           episode_num=episode_num,
                                           transition_all=transition_all,
                                           values_cond=values_cond,
                                           maf_model_save_mother_dir=maf_model_save_mother_dir,
                                           dqn_load_from_path=dqn_load_from_path,
                                           sanity_check_msg=sanity_check_msg,
                                           debug_mode=debug_mode,
                                           debug_msg=debug_msg,
                                           running_avg_maf_loss=running_avg_maf_loss,
                                           running_avg_log_prob=running_avg_log_prob,
                                           log_file=log_file)
    finally:
        # Close the log on every exit path, including errors raised during training.
        if log_file is not None:
            log_file.close()


def test(args):
    raise ValueError("please use the run_hockey_evaluate for testing")


if __name__ == "__main__":
    # Script entry point: a non-zero TRAIN_FLAG selects training, anything else
    # falls through to test().
    args = read_args()
    entry_point = train if int(args.TRAIN_FLAG) else test
    entry_point(args)
